"""
Problem 49: https://projecteuler.net/problem=49

Prime permutations

The arithmetic sequence, 1487, 4817, 8147, in which each of
the terms increases by 3330, is unusual in two ways:
(i) each of the three terms are prime,
(ii) each of the 4-digit numbers are permutations of one another.

There are no arithmetic sequences made up of three 1-, 2-, or 3-digit primes,
exhibiting this property, but there is one other 4-digit increasing sequence.

What 12-digit number do you form by concatenating the three terms in this sequence?
"""

# -*- coding: UTF-8 -*-
'''
@author = Kuperain
@email = kuperain@aliyun.com
@IDE = VSCODE Python3.8.3
@creat_time = 2022/5/15
'''


# Sieve of Eratosthenes: after this runs, Primes[n] is True iff n is prime,
# for every 0 <= n < N.
N = 10000
Primes = [True] * N
Primes[0] = False
Primes[1] = False
for i in range(2, int(N ** 0.5) + 1):
    # Only primes need to cross off multiples; multiples of a composite were
    # already removed by its prime factors.
    if Primes[i]:
        # Multiples below i*i have a smaller prime factor and are already
        # marked, so start at i*i.
        for multiple in range(i * i, N, i):
            Primes[multiple] = False


def solution1(d: int = 4) -> list:
    '''
    Brute-force search over pairs of d-digit primes.

    For every pair of d-digit primes n1 < n2 whose midpoint
    mid = (n1 + n2) // 2 is also prime, keep (n1, mid, n2) when all three
    numbers are digit permutations of one another. Each such triple is an
    increasing arithmetic sequence with common difference mid - n1.

    Uses the module-level sieve ``Primes``, so 10**d must not exceed its
    length N (d <= 4 with N = 10000); larger d raises IndexError.

    Returns a list of (n1, mid, n2) tuples ordered by n1, then n2.

    >>> print(solution1(1))  # 1-digit
    []
    >>> print(solution1(2))  # 2-digit
    []
    >>> print(solution1(3))  # 3-digit
    []
    >>> print(solution1(4))  # 4-digit
    [(1487, 4817, 8147), (2969, 6299, 9629)]
    '''
    res = []

    # Only odd candidates are tried: 2 cannot be part of an increasing
    # 3-term arithmetic sequence of primes. `| 1` rounds 10**(d-1) up to an
    # odd start (1 for d=1, 10**(d-1) + 1 otherwise).
    start = 10**(d-1) | 1
    stop = 10**d
    for n1 in range(start, stop, 2):
        if not Primes[n1]:
            continue
        key = sorted(str(n1))
        # n2 keeps n1's parity so n1 + n2 is even and mid is exact;
        # n2 = n1 + 2 is skipped because mid = n1 + 1 would be even.
        for n2 in range(n1 + 4, stop, 2):
            mid = (n1 + n2) // 2
            if (Primes[n2] and Primes[mid]
                    and sorted(str(n2)) == key == sorted(str(mid))):
                res.append((n1, mid, n2))
    return res


def solution2(d: int = 4) -> set:
    '''
    Search group by group, where a group is one multiset of d digits.

    For each multiset of d decimal digits, collect the d-digit primes
    (no leading zero) that are permutations of it. Any two such primes
    n1 < n2 whose midpoint (n1 + n2) / 2 is also in the group form the
    arithmetic sequence (n1, mid, n2).

    Uses the module-level sieve ``Primes``, so 10**d must not exceed its
    length N (d <= 4 with N = 10000).

    Returns a set of (n1, mid, n2) tuples.

    >>> assert not solution2(1)  # 1-digit
    >>> assert not solution2(2)  # 2-digit
    >>> assert not solution2(3)  # 3-digit
    >>> assert solution2(4) == {(2969, 6299, 9629), (1487, 4817, 8147)}  # 4-digit
    '''
    res = set()

    import itertools
    # combinations_with_replacement yields each sorted digit multiset exactly
    # once, unlike combinations over '0123456789'*d, which repeats every
    # multiset many times.
    for digits in itertools.combinations_with_replacement('0123456789', d):
        candidates = {''.join(p) for p in itertools.permutations(digits)}
        group = sorted(int(s) for s in candidates
                       if s[0] != '0' and Primes[int(s)])
        if len(group) < 3:
            continue
        members = set(group)
        for i, low in enumerate(group):
            for high in group[i + 1:]:
                total = low + high
                # An odd sum has no integer midpoint; floor division would
                # otherwise report a spurious sequence.
                if total % 2 == 0 and total // 2 in members:
                    res.add((low, total // 2, high))
    return res


if __name__ == "__main__":
    import doctest
    doctest.testmod(verbose=False)

    sequences = solution1()
    print(sequences)
    # [(1487, 4817, 8147), (2969, 6299, 9629)]

    print(solution2())
    # {(1487, 4817, 8147), (2969, 6299, 9629)}  (set display order may vary)

    # The problem asks for the 12-digit concatenation of the sequence other
    # than the 1487, 4817, 8147 example given in the statement.
    for seq in sequences:
        if seq != (1487, 4817, 8147):
            print(''.join(map(str, seq)))
    # 296962999629
